package com.offcn.bigdata.spark.p3

import java.sql.DriverManager

import com.mysql.jdbc.{Connection, Driver, JDBC42PreparedStatement}
import com.offcn.bigdata.common.db.ConnectionPool
import org.apache.spark.rdd.RDD
import org.apache.spark.serializer.KryoSerializer
import org.apache.spark.{SparkConf, SparkContext}

/**
  * foreachPartition --- foreach
  * mapPartition     --- map
  *
  * CREATE TABLE `words` (
      `word` varchar(20) NOT NULL,
      `count` int(11) DEFAULT NULL,
      PRIMARY KEY (`word`)
    ) ENGINE=InnoDB DEFAULT CHARSET=utf8
  */
object _01ForeachPartitionOps {
    def main(args: Array[String]): Unit = {
        // Local-mode context for the word-count demo. Kryo registration of the
        // JDBC classes (commented out historically) is only relevant for the
        // deliberately broken save2DB variant that tries to serialize a
        // PreparedStatement.
        val sparkConf = new SparkConf()
            .setAppName("ForeachPartitionOps")
            .setMaster("local[*]")
        val sc = new SparkContext(sparkConf)

        // Tiny in-memory corpus, forced into a single partition.
        val lines = sc.parallelize(List(
            "hello xiao lei lei",
            "hao xiao lei",
            "hello xiao hao chao",
            "hao xiao chao xu shit"
        ), 1)

        // Classic word count: tokenize on whitespace, pair each word with 1,
        // then sum per key (equivalent to the commented-out combineByKey form).
        val wordCounts = lines
            .flatMap(line => line.split("\\s+"))
            .map(word => (word, 1))
            .aggregateByKey(0)(_ + _, _ + _)

        // Other sinks demonstrated in this file, roughly worst to best:
        // save2DB, save2DBByForeach, save2DBByForeachPartition,
        // save2DBByBatch, save2DBByBatchBatch.
        save2DBByPool(wordCounts)

        sc.stop()
    }

    /**
      * Writes (word, count) pairs to MySQL reusing a pooled connection per partition.
      *
      * Improvements over the DriverManager-based variants below:
      *   - the connection comes from (and is returned to) ConnectionPool, avoiding
      *     the cost of opening a fresh connection for every partition;
      *   - records are flushed every 3 rows so a large partition never builds one
      *     huge JDBC batch (same idea as save2DBByBatchBatch).
      *
      * Fix: the statement is closed and the connection released in `finally`
      * blocks, so a failure mid-partition no longer leaks the pooled connection.
      *
      * @param rdd word -> count pairs to persist
      */
    def save2DBByPool(rdd: RDD[(String, Int)]): Unit = {
        rdd.foreachPartition(iterator => {
            // Runs on the executor, once per partition: all JDBC objects are
            // created locally, so nothing has to be serialized by Spark.
            val connection = ConnectionPool.getConnection()
            try {
                val sql =
                    """
                      |insert into words(word, `count`) values (?, ?)
                    """.stripMargin
                val ps = connection.prepareStatement(sql)
                try {
                    // Flush every `batchSize` buffered records to cap batch memory/IO.
                    val batchSize = 3
                    var count = 0
                    iterator.foreach { case (word, value) =>
                        if (count == batchSize) {
                            count = 0
                            ps.executeBatch()
                            println("-------------------------")
                        }
                        count += 1
                        ps.setString(1, word)
                        ps.setInt(2, value)
                        ps.addBatch()
                    }
                    // Flush the trailing partial batch, if any.
                    if (count != 0) {
                        ps.executeBatch()
                        println("-------------------------")
                    }
                } finally {
                    ps.close()
                }
            } finally {
                // Always return the connection to the pool, even on failure.
                ConnectionPool.release(connection)
            }
        })
    }
    /**
      * Batch insert with periodic flushes: on top of plain batching, submit the
      * batch every 3 records so that a large partition does not accumulate one
      * huge JDBC batch — this limits the IO and memory pressure of a single
      * executeBatch call.
      *
      * Fix: the statement and connection are now closed in `finally` blocks, so
      * a failure mid-partition no longer leaks JDBC resources.
      *
      * @param rdd word -> count pairs to persist
      */
    def save2DBByBatchBatch(rdd: RDD[(String, Int)]): Unit = {
        rdd.foreachPartition(iterator => {
            println("-------------------------")
            // Runs on the executor within one partition.
            // Referencing the class forces the MySQL driver to load/register.
            classOf[Driver]
            val connection = DriverManager.getConnection(
                "jdbc:mysql://localhost:3306/test",
                "root",
                "sorry"
            )
            try {
                val sql =
                    """
                      |insert into words(word, `count`) values (?, ?)
                    """.stripMargin
                val ps = connection.prepareStatement(sql)
                try {
                    // Submit the batch every 3 buffered records.
                    var count = 0
                    iterator.foreach { case (word, value) =>
                        if (count == 3) {
                            count = 0
                            ps.executeBatch()
                        }
                        count += 1
                        ps.setString(1, word)
                        ps.setInt(2, value)
                        ps.addBatch()
                    }
                    // Flush whatever remains in the final partial batch.
                    if (count != 0) {
                        ps.executeBatch()
                    }
                } finally {
                    ps.close()
                }
            } finally {
                connection.close()
            }
        })
    }

    /**
      * Batch insert, one submit per partition. The previous foreachPartition
      * example created one connection per partition but still executed one SQL
      * statement per record, which keeps the IO load high — here all records of
      * the partition are buffered with addBatch and submitted in a single
      * executeBatch call.
      *
      * Fix: the statement and connection are now closed in `finally` blocks, so
      * a failure mid-partition no longer leaks JDBC resources.
      *
      * @param rdd word -> count pairs to persist
      */
    def save2DBByBatch(rdd: RDD[(String, Int)]): Unit = {
        rdd.foreachPartition(iterator => {
            println("-------------------------")
            // Runs on the executor within one partition.
            // Referencing the class forces the MySQL driver to load/register.
            classOf[Driver]
            val connection = DriverManager.getConnection(
                "jdbc:mysql://localhost:3306/test",
                "root",
                "sorry"
            )
            try {
                val sql =
                    """
                      |insert into words(word, `count`) values (?, ?)
                    """.stripMargin
                val ps = connection.prepareStatement(sql)
                try {
                    // `iterator` holds this partition's data locally; buffer
                    // every record, then submit once.
                    iterator.foreach { case (word, count) =>
                        ps.setString(1, word)
                        ps.setInt(2, count)
                        ps.addBatch()
                    }
                    ps.executeBatch()
                } finally {
                    ps.close()
                }
            } finally {
                connection.close()
            }
        })
    }

    /**
      * One connection per PARTITION instead of one per record: the plain
      * foreach version below opens a database connection for every single
      * record, which is far too expensive — foreachPartition amortizes the
      * connection over all records of the partition (still one execute per
      * record; see save2DBByBatch for batching).
      *
      * Fix: the statement and connection are now closed in `finally` blocks, so
      * a failure mid-partition no longer leaks JDBC resources.
      *
      * @param rdd word -> count pairs to persist
      */
    def save2DBByForeachPartition(rdd: RDD[(String, Int)]): Unit = {
        rdd.foreachPartition(iterator => {
            println("-------------------------")
            // Runs on the executor within one partition.
            // Referencing the class forces the MySQL driver to load/register.
            classOf[Driver]
            val connection = DriverManager.getConnection(
                "jdbc:mysql://localhost:3306/test",
                "root",
                "sorry"
            )
            try {
                val sql =
                    """
                      |insert into words(word, `count`) values (?, ?)
                    """.stripMargin
                val ps = connection.prepareStatement(sql)
                try {
                    // `iterator` holds this partition's data locally;
                    // one execute per record.
                    iterator.foreach { case (word, count) =>
                        ps.setString(1, word)
                        ps.setInt(2, count)
                        ps.execute()
                    }
                } finally {
                    ps.close()
                }
            } finally {
                connection.close()
            }
        })
    }

    /**
      * Per-record insert via foreach. Unlike save2DB, the connection and
      * statement are created on the executor where the data lives, so nothing
      * needs to cross the driver/executor boundary and no serialization issue
      * arises — but opening a connection for EVERY record is extremely
      * inefficient; prefer the foreachPartition variants.
      *
      * Fix: resources are now closed in `finally` blocks, so a failed insert
      * no longer leaks the connection.
      *
      * @param rdd word -> count pairs to persist
      */
    def save2DBByForeach(rdd: RDD[(String, Int)]): Unit = {
        rdd.foreach { case (word, count) =>
            // step 1: force-load the MySQL driver class
            classOf[Driver]
            // step 2: one connection PER RECORD (deliberately naive)
            val connection = DriverManager.getConnection(
                "jdbc:mysql://localhost:3306/test",
                "root",
                "sorry"
            )
            try {
                // step 3: prepare and execute a single-row insert
                val sql =
                    """
                      |insert into words(word, `count`) values (?, ?)
                    """.stripMargin
                val ps = connection.prepareStatement(sql)
                try {
                    ps.setString(1, word)
                    ps.setInt(2, count)
                    ps.execute()
                } finally {
                    ps.close()
                }
            } finally {
                connection.close()
            }
        }
    }
    /**
      * Writes the RDD's (word, count) pairs to MySQL — BROKEN on purpose.
      *
      * The six classic JDBC steps:
      *   load the driver, obtain a Connection, create a statement,
      *   execute the SQL, handle the result set, release the resources.
      *
      * Fails with NotSerializableException: com.mysql.jdbc.JDBC42PreparedStatement.
      * The PreparedStatement is created here on the DRIVER, but rdd.foreach runs
      * on the EXECUTORS, so Spark must serialize `ps` and ship it to each
      * executor's partitions — and JDBC statements are not serializable.
      * See save2DBByForeach / save2DBByForeachPartition for working alternatives.
      */
    def save2DB(rdd: RDD[(String, Int)]): Unit = {
        //step 1: force-load the MySQL driver class
        classOf[Driver]
        //step 2: driver-side connection
        val connection = DriverManager.getConnection(
            "jdbc:mysql://localhost:3306/test",
            "root",
            "sorry"
        )
        //step 3: driver-side prepared statement
        val sql =
            """
              |insert into words(word, `count`) values (?, ?)
            """.stripMargin
        val ps = connection.prepareStatement(sql)

        // `ps` is captured by this closure -> serialization fails at job submit.
        rdd.foreach{case (word, count) => {
            ps.setString(1, word)
            ps.setInt(2, count)
            ps.execute()
        }}
        ps.close()
        connection.close()
    }
}
